{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Copyright (c) Microsoft Corporation. All rights reserved.*  \n",
    "*Licensed under the MIT License.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Named Entity Recognition Using BERT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Before You Start\n",
    "\n",
    "The running time shown in this notebook is on a Standard_NC6 Azure Deep Learning Virtual Machine with 1 NVIDIA Tesla K80 GPU. \n",
    "> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n",
    "\n",
    "The table below provides some reference running time on different machine configurations.  \n",
    "\n",
    "|QUICK_RUN|Machine Configurations|Running time|\n",
    "|:---------|:----------------------|:------------|\n",
    "|True|4 **CPU**s, 14GB memory| ~ 2 minutes|\n",
    "|False|4 **CPU**s, 14GB memory| ~1.5 hours|\n",
    "|True|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 1 minute|\n",
    "|False|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 7 minutes |\n",
    "\n",
    "If you run into CUDA out-of-memory error or the jupyter kernel dies constantly, try reducing the `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance will be compromised. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n",
    "QUICK_RUN = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "This notebook demonstrates how to fine tune [pretrained BERT model](https://github.com/huggingface/pytorch-pretrained-BERT) for named entity recognition (NER) task. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, and model evaluation. \n",
    "\n",
    "[BERT (Bidirectional Transformers for Language Understanding)](https://arxiv.org/pdf/1810.04805.pdf) is a powerful pre-trained lanaguage model that can be used for multiple NLP tasks, including text classification, question answering, named entity recognition, etc. It's able to achieve state of the art performance with only a few epochs of fine tuning on task specific datasets.  \n",
    "The figure below illustrates how BERT can be fine tuned for NER tasks. The input data is a list of tokens representing a sentence. In the training data, each token has an entity label. After fine tuning, the model predicts an entity label for each token in a given testing sentence. \n",
    "\n",
    "<img src=\"https://nlpbp.blob.core.windows.net/images/bert_architecture.png\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import random\n",
    "import scrapbook as sb\n",
    "from seqeval.metrics import classification_report\n",
    "\n",
    "import torch\n",
    "\n",
    "nlp_path = os.path.abspath('../../')\n",
    "if nlp_path not in sys.path:\n",
    "    sys.path.insert(0, nlp_path)\n",
    "\n",
    "from utils_nlp.models.bert.token_classification import BERTTokenClassifier, create_label_map, postprocess_token_labels\n",
    "from utils_nlp.models.bert.common import Language, Tokenizer\n",
    "from utils_nlp.dataset.wikigold import load_train_test_dfs, get_unique_labels\n",
    "from utils_nlp.common.timer import Timer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true,
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "TRAIN_DATA_FRACTION = 1\n",
    "TEST_DATA_FRACTION = 1\n",
    "NUM_TRAIN_EPOCHS = 5\n",
    "\n",
    "if QUICK_RUN:\n",
    "    TRAIN_DATA_FRACTION = 0.1\n",
    "    TEST_DATA_FRACTION = 0.1\n",
    "    NUM_TRAIN_EPOCHS = 1\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    BATCH_SIZE = 16\n",
    "else:\n",
    "    BATCH_SIZE = 8\n",
    "\n",
    "CACHE_DIR=\"./temp\"\n",
    "\n",
    "# set random seeds\n",
    "RANDOM_SEED = 100\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "# model configurations\n",
    "LANGUAGE = Language.ENGLISHCASED\n",
    "DO_LOWER_CASE = False\n",
    "MAX_SEQ_LENGTH = 200\n",
    "\n",
    "# optimizer configuration\n",
    "LEARNING_RATE = 3e-5\n",
    "\n",
    "# data configurations\n",
    "TEXT_COL = \"sentence\"\n",
    "LABELS_COL = \"labels\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get training and testing data\n",
    "The dataset used in this notebook is the [wikigold dataset](https://www.aclweb.org/anthology/W09-3302). The wikigold dataset consists of 145 mannually labelled Wikipedia articles, including 1841 sentences and 40k tokens in total. The dataset can be directly downloaded from [here](https://github.com/juand-r/entity-recognition-datasets/tree/master/data/wikigold). \n",
    "\n",
    "The helper function `load_train_test_dfs` downloads the data file if it doesn't exist in `local_cache_path`. It splits the dataset into training and testing sets according to `test_fraction`. Because this is a relatively small dataset, we set `test_fraction` to 0.5 in order to have enough data for model evaluation. Running this notebook multiple times with different random seeds produces similar results.   \n",
    "\n",
    "The helper function `get_unique_labels` returns the unique entity labels in the dataset. There are 5 unique labels in the   original dataset: 'O' (non-entity), 'I-LOC' (location), 'I-MISC' (miscellaneous), 'I-PER' (person), and 'I-ORG' (organization). \n",
    "\n",
    "The maximum number of words in a sentence is 144, so we set MAX_SEQ_LENGTH to 200 above, because the number of tokens will grow after WordPiece tokenization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum sequence length in the  data is: 144\n",
      "\n",
      "Unique entity labels: \n",
      "['O', 'I-LOC', 'I-MISC', 'I-PER', 'I-ORG']\n",
      "\n",
      "Sample sentence: \n",
      "['Two', ',', 'Samsung', 'based', ',', 'electronic', 'cash', 'registers', 'were', 'reconstructed', 'in', 'order', 'to', 'expand', 'their', 'functions', 'and', 'adapt', 'them', 'for', 'networking', '.']\n",
      "\n",
      "Sample sentence labels: \n",
      "['O', 'O', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n",
      "\n"
     ]
    }
   ],
   "source": [
    "train_df, test_df = load_train_test_dfs(local_cache_path=CACHE_DIR, test_fraction=0.5,random_seed=RANDOM_SEED)\n",
    "label_list = get_unique_labels()\n",
    "print('\\nUnique entity labels: \\n{}\\n'.format(label_list))\n",
    "print('Sample sentence: \\n{}\\n'.format(train_df[TEXT_COL][0]))\n",
    "print('Sample sentence labels: \\n{}\\n'.format(train_df[LABELS_COL][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df = train_df.sample(frac=TRAIN_DATA_FRACTION).reset_index(drop=True)\n",
    "test_df = test_df.sample(frac=TEST_DATA_FRACTION).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note that the input text are lists of words instead of raw sentences. This format ensures matching between input words and token labels when the words are further tokenized by Tokenizer.tokenize_ner.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tokenization and Preprocessing\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Create a dictionary that maps labels to numerical values**  \n",
    "Note there is an argument called `trailing_piece_tag`. BERT uses a WordPiece tokenizer which breaks down some words into multiple tokens, e.g. \"criticize\" is tokenized into \"critic\" and \"##ize\". Since the input data only come with one token label for \"criticize\", within Tokenizer.prerocess_ner_tokens, the original token label is assigned to the first token \"critic\" and the second token \"##ize\" is labeled as \"X\". By default, `trailing_piece_tag` is set to \"X\". If \"X\" already exists in your data, you can set `trailing_piece_tag` to another value that doesn't exist in your data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "label_map = create_label_map(label_list, trailing_piece_tag=\"X\")"
   ]
  },
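  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the labeling rule described above, the sketch below builds a label map and spreads word-level labels over word pieces. It is not the `utils_nlp` implementation: `build_label_map`, `label_word_pieces`, and the `TOY_PIECES` lookup are hypothetical names, the id ordering is an assumption, and a real WordPiece tokenizer derives the pieces from its vocabulary.\n",
    "\n",
    "```python\n",
    "# Hypothetical stand-in for a WordPiece vocabulary lookup (assumption).\n",
    "TOY_PIECES = {'criticize': ['critic', '##ize']}\n",
    "\n",
    "\n",
    "def build_label_map(labels, trailing_piece_tag='X'):\n",
    "    # Map each label to an integer id; the trailing-piece tag gets the last id.\n",
    "    label_map = {label: idx for idx, label in enumerate(labels)}\n",
    "    if trailing_piece_tag not in label_map:\n",
    "        label_map[trailing_piece_tag] = len(label_map)\n",
    "    return label_map\n",
    "\n",
    "\n",
    "def label_word_pieces(words, word_labels, trailing_piece_tag='X'):\n",
    "    # The first piece of each word keeps the word label; later pieces get the tag.\n",
    "    pieces, piece_labels = [], []\n",
    "    for word, label in zip(words, word_labels):\n",
    "        word_pieces = TOY_PIECES.get(word, [word])\n",
    "        pieces.extend(word_pieces)\n",
    "        piece_labels.extend([label] + [trailing_piece_tag] * (len(word_pieces) - 1))\n",
    "    return pieces, piece_labels\n",
    "\n",
    "\n",
    "toy_label_map = build_label_map(['O', 'I-LOC', 'I-MISC', 'I-PER', 'I-ORG'])\n",
    "toy_pieces, toy_piece_labels = label_word_pieces(['Critics', 'criticize', 'Samsung'],\n",
    "                                                 ['O', 'O', 'I-ORG'])\n",
    "print(toy_pieces)\n",
    "print(toy_piece_labels)\n",
    "print([toy_label_map[label] for label in toy_piece_labels])\n",
    "```\n",
    "\n",
    "Under these assumptions, \"##ize\" receives the tag \"X\", which maps to id 5 here, while the first piece of each word keeps its word-level label."
   ]
  },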
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Create a tokenizer**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "tokenizer = Tokenizer(language=LANGUAGE, \n",
    "                      to_lower=DO_LOWER_CASE, \n",
    "                      cache_dir=CACHE_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Tokenize and preprocess text**  \n",
    "The `tokenize_ner` method of the `Tokenizer` class converts text and labels in strings to numerical features, involving the following steps:\n",
    "1. WordPiece tokenization.\n",
    "2. Convert tokens and labels to numerical values, i.e. token ids and label ids.\n",
    "3. Sequence padding or truncation according to the `max_seq_length` configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_token_ids, train_input_mask, train_trailing_token_mask, train_label_ids = \\\n",
    "    tokenizer.tokenize_ner(text=train_df[TEXT_COL],\n",
    "                           label_map=label_map,\n",
    "                           max_len=MAX_SEQ_LENGTH,\n",
    "                           labels=train_df[LABELS_COL],\n",
    "                           trailing_piece_tag=\"X\")\n",
    "test_token_ids, test_input_mask, test_trailing_token_mask, test_label_ids = \\\n",
    "    tokenizer.tokenize_ner(text=test_df[TEXT_COL],\n",
    "                           label_map=label_map,\n",
    "                           max_len=MAX_SEQ_LENGTH,\n",
    "                           labels=test_df[LABELS_COL],\n",
    "                           trailing_piece_tag=\"X\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Tokenizer.tokenize_ner` outputs three or four lists of numerical features lists, each sublist contains features of an input sentence: \n",
    "1. token ids: list of numerical values each corresponds to a token.\n",
    "2. attention mask: list of 1s and 0s, 1 for input tokens and 0 for padded tokens, so that padded tokens are not attended to. \n",
    "3. trailing word piece mask: boolean list, `True` for the first word piece of each original word, `False` for the trailing word pieces, e.g. ##ize. This mask is useful for removing predictions on trailing word pieces, so that each original word in the input text has a unique predicted label. \n",
    "4. label ids: list of numerical values each corresponds to an entity label, if `labels` is provided."
   ]
  },
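  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below mimics, on a toy scale, the padding step and the masks listed above. `pad_features` is a hypothetical helper, not part of `utils_nlp`. Padding ids and labels with 0 and marking padded positions as `True` in the trailing mask are assumptions chosen to match the sample output printed below.\n",
    "\n",
    "```python\n",
    "def pad_features(piece_ids, piece_labels, max_len, trailing_label_id, pad_id=0):\n",
    "    # Truncate to max_len, then right-pad ids, mask, and labels.\n",
    "    piece_ids = piece_ids[:max_len]\n",
    "    piece_labels = piece_labels[:max_len]\n",
    "    n_pad = max_len - len(piece_ids)\n",
    "    token_ids = piece_ids + [pad_id] * n_pad\n",
    "    input_mask = [1] * len(piece_ids) + [0] * n_pad\n",
    "    label_ids = piece_labels + [pad_id] * n_pad\n",
    "    # False marks trailing word pieces; all other positions, padding included, are True.\n",
    "    trailing_token_mask = [label != trailing_label_id for label in piece_labels] + [True] * n_pad\n",
    "    return token_ids, input_mask, trailing_token_mask, label_ids\n",
    "\n",
    "\n",
    "toy_ids, toy_mask, toy_trailing, toy_labels = pad_features(\n",
    "    piece_ids=[142, 3169, 9421], piece_labels=[4, 5, 5], max_len=5, trailing_label_id=5)\n",
    "print(toy_ids)\n",
    "print(toy_mask)\n",
    "print(toy_trailing)\n",
    "print(toy_labels)\n",
    "```\n",
    "\n",
    "All four lists have length `max_len`, so sentences of different lengths can be stacked into fixed-size batches."
   ]
  },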
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample token ids:\n",
      "[142, 3169, 9421, 24164, 112, 188, 1927, 1646, 2269, 16372, 2500, 1127, 2888, 1118, 24164, 118, 21452, 3991, 3066, 1107, 1898, 1196, 1217, 3306, 1118, 1203, 3991, 113, 23266, 114, 1105, 2028, 3733, 1154, 1103, 11594, 25702, 1162, 4097, 119, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
      "\n",
      "Sample attention mask:\n",
      "[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
      "\n",
      "Sample trailing token mask:\n",
      "[True, False, False, True, True, False, True, True, True, True, True, True, True, True, True, False, False, False, True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]\n",
      "\n",
      "Sample label ids:\n",
      "[4, 5, 5, 4, 0, 5, 0, 1, 0, 0, 0, 0, 0, 0, 4, 5, 5, 5, 4, 0, 0, 0, 0, 0, 0, 4, 5, 0, 4, 0, 0, 0, 0, 0, 0, 4, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"Sample token ids:\\n{}\\n\".format(train_token_ids[0]))\n",
    "print(\"Sample attention mask:\\n{}\\n\".format(train_input_mask[0]))\n",
    "print(\"Sample trailing token mask:\\n{}\\n\".format(train_trailing_token_mask[0]))\n",
    "print(\"Sample label ids:\\n{}\\n\".format(train_label_ids[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Token Classifier\n",
    "The value of the `language` argument determines which BERT model is used:\n",
    "* Language.ENGLISH: \"bert-base-uncased\"\n",
    "* Language.ENGLISHCASED: \"bert-base-cased\"\n",
    "* Language.ENGLISHLARGE: \"bert-large-uncased\"\n",
    "* Language.ENGLISHLARGECASED: \"bert-large-cased\"\n",
    "* Language.CHINESE: \"bert-base-chinese\"\n",
    "* Language.MULTILINGUAL: \"bert-base-multilingual-cased\"\n",
    "* Language.ENGLISHLARGEWWM: \"bert-large-uncased-whole-word-masking\"\n",
    "* Language.ENGLISHLARGECASEDWWM: \"bert-large-cased-whole-word-masking\"\n",
    "\n",
    "Here we use the base, cased pretrained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "token_classifier = BERTTokenClassifier(language=LANGUAGE,\n",
    "                                       num_labels=len(label_map),\n",
    "                                       cache_dir=CACHE_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "t_total value of -1 results in schedule not being applied\n",
      "Epoch:   0%|          | 0/5 [00:00<?, ?it/s]\n",
      "Iteration:   0%|          | 0/58 [00:00<?, ?it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Only 1 CUDA device is available. Data parallelism is not possible.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Iteration:  40%|███▉      | 23/58 [00:30<00:46,  1.31s/it]\u001b[A\n",
      "Iteration:  40%|███▉      | 23/58 [00:49<00:46,  1.31s/it]\u001b[A\n",
      "Iteration:  79%|███████▉  | 46/58 [01:00<00:15,  1.31s/it]\u001b[A\n",
      "Epoch:  20%|██        | 1/5 [01:16<05:04, 76.21s/it]1s/it]\u001b[A\n",
      "Iteration:   0%|          | 0/58 [00:00<?, ?it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.44944492534830655\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Iteration:  40%|███▉      | 23/58 [00:30<00:46,  1.32s/it]\u001b[A\n",
      "Iteration:  40%|███▉      | 23/58 [00:43<00:46,  1.32s/it]\u001b[A\n",
      "Iteration:  79%|███████▉  | 46/58 [01:00<00:15,  1.32s/it]\u001b[A\n",
      "Iteration:  79%|███████▉  | 46/58 [01:13<00:15,  1.32s/it]\u001b[A\n",
      "Epoch:  40%|████      | 2/5 [02:32<03:48, 76.30s/it]1s/it]\u001b[A\n",
      "Iteration:   0%|          | 0/58 [00:00<?, ?it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.0985172498200474\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Iteration:  40%|███▉      | 23/58 [00:30<00:46,  1.32s/it]\u001b[A\n",
      "Iteration:  40%|███▉      | 23/58 [00:47<00:46,  1.32s/it]\u001b[A\n",
      "Iteration:  79%|███████▉  | 46/58 [01:00<00:15,  1.32s/it]\u001b[A\n",
      "Epoch:  60%|██████    | 3/5 [03:49<02:32, 76.39s/it]1s/it]\u001b[A\n",
      "Iteration:   0%|          | 0/58 [00:00<?, ?it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.04283706506649996\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Iteration:  40%|███▉      | 23/58 [00:30<00:46,  1.32s/it]\u001b[A\n",
      "Iteration:  40%|███▉      | 23/58 [00:40<00:46,  1.32s/it]\u001b[A\n",
      "Iteration:  79%|███████▉  | 46/58 [01:00<00:15,  1.32s/it]\u001b[A\n",
      "Epoch:  80%|████████  | 4/5 [05:06<01:16, 76.48s/it]1s/it]\u001b[A\n",
      "Iteration:   0%|          | 0/58 [00:00<?, ?it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.021370104799079227\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Iteration:  40%|███▉      | 23/58 [00:30<00:46,  1.33s/it]\u001b[A\n",
      "Iteration:  40%|███▉      | 23/58 [00:44<00:46,  1.33s/it]\u001b[A\n",
      "Iteration:  79%|███████▉  | 46/58 [01:00<00:15,  1.33s/it]\u001b[A\n",
      "Iteration:  79%|███████▉  | 46/58 [01:14<00:15,  1.33s/it]\u001b[A\n",
      "Epoch: 100%|██████████| 5/5 [06:22<00:00, 76.57s/it]2s/it]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.01407645793695902\n",
      "Training time : 0.107 hrs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "with Timer() as t:\n",
    "    token_classifier.fit(token_ids=train_token_ids, \n",
    "                         input_mask=train_input_mask, \n",
    "                         labels=train_label_ids,\n",
    "                         num_epochs=NUM_TRAIN_EPOCHS, \n",
    "                         batch_size=BATCH_SIZE, \n",
    "                         learning_rate=LEARNING_RATE)\n",
    "print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict on Test Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Iteration:   0%|          | 0/58 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Only 1 CUDA device is available. Data parallelism is not possible.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration: 100%|██████████| 58/58 [00:25<00:00,  2.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation loss: 0.10695629087597902\n",
      "Prediction time : 0.007 hrs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "with Timer() as t:\n",
    "    pred_label_ids = token_classifier.predict(token_ids=test_token_ids, \n",
    "                                              input_mask=test_input_mask, \n",
    "                                              labels=test_label_ids, \n",
    "                                              batch_size=BATCH_SIZE)\n",
    "print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Model\n",
    "The `predict` method of the token classifier outputs label ids for all tokens, including the padded tokens. `postprocess_token_labels` is a helper function that removes the predictions on padded tokens. If a `label_map` is provided, it maps the numerical label ids back to original token labels which are usually string type. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "           precision    recall  f1-score   support\n",
      "\n",
      "        X       0.97      0.98      0.97      1983\n",
      "     MISC       0.69      0.86      0.77       396\n",
      "      ORG       0.81      0.77      0.79       538\n",
      "      PER       0.95      0.91      0.93       550\n",
      "      LOC       0.85      0.91      0.88       543\n",
      "\n",
      "micro avg       0.90      0.92      0.91      4010\n",
      "macro avg       0.90      0.92      0.91      4010\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pred_tags_no_padding = postprocess_token_labels(pred_label_ids, \n",
    "                                                test_input_mask, \n",
    "                                                label_map)\n",
    "true_tags_no_padding = postprocess_token_labels(test_label_ids, \n",
    "                                                test_input_mask, \n",
    "                                                label_map)\n",
    "report_no_padding = classification_report(true_tags_no_padding, \n",
    "                                          pred_tags_no_padding, \n",
    "                                          digits=2)\n",
    "print(report_no_padding)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`postprocess_token_labels` also provides an option to remove the predictions on trailing word pieces, e.g. ##ize, so that the final predicted labels correspond to the original words in the input text. The `trailing_token_mask` is obtained from `tokenizer.tokenize_ner`"
   ]
  },
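  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two clean-up steps described in this section can be sketched as follows. `strip_predictions` is a hypothetical stand-in written for illustration, not the `postprocess_token_labels` implementation, and the toy inputs are made up.\n",
    "\n",
    "```python\n",
    "def strip_predictions(label_ids, input_mask, label_map, trailing_token_mask=None):\n",
    "    # Invert the label map so ids can be turned back into string labels.\n",
    "    id_to_label = {idx: label for label, idx in label_map.items()}\n",
    "    cleaned = []\n",
    "    for i, (ids, mask) in enumerate(zip(label_ids, input_mask)):\n",
    "        keep = [bool(m) for m in mask]  # drop padded positions\n",
    "        if trailing_token_mask is not None:\n",
    "            # Also drop trailing word pieces (mask value False).\n",
    "            keep = [k and t for k, t in zip(keep, trailing_token_mask[i])]\n",
    "        cleaned.append([id_to_label[j] for j, k in zip(ids, keep) if k])\n",
    "    return cleaned\n",
    "\n",
    "\n",
    "demo_map = {'O': 0, 'I-ORG': 4, 'X': 5}\n",
    "demo_pred = [[4, 5, 5, 0, 0]]\n",
    "demo_mask = [[1, 1, 1, 0, 0]]\n",
    "demo_trailing = [[True, False, False, True, True]]\n",
    "print(strip_predictions(demo_pred, demo_mask, demo_map))\n",
    "print(strip_predictions(demo_pred, demo_mask, demo_map, demo_trailing))\n",
    "```\n",
    "\n",
    "With only the padding removed, the toy sentence keeps its two \"X\" labels. With the trailing mask applied as well, one label per original word remains."
   ]
  },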
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "           precision    recall  f1-score   support\n",
      "\n",
      "      ORG       0.78      0.77      0.77       442\n",
      "     MISC       0.68      0.84      0.75       344\n",
      "      LOC       0.85      0.91      0.88       503\n",
      "      PER       0.93      0.90      0.91       455\n",
      "\n",
      "micro avg       0.79      0.86      0.82      1744\n",
      "macro avg       0.82      0.86      0.83      1744\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pred_tags_no_padding_no_trailing = postprocess_token_labels(pred_label_ids, \n",
    "                                                            test_input_mask, \n",
    "                                                            label_map, \n",
    "                                                            remove_trailing_word_pieces=True, \n",
    "                                                            trailing_token_mask=test_trailing_token_mask)\n",
    "true_tags_no_padding_no_trailing = postprocess_token_labels(test_label_ids, \n",
    "                                                            test_input_mask, \n",
    "                                                            label_map, \n",
    "                                                            remove_trailing_word_pieces=True, \n",
    "                                                            trailing_token_mask=test_trailing_token_mask)\n",
    "report_no_padding_no_trailing = classification_report(true_tags_no_padding_no_trailing, \n",
    "                                                      pred_tags_no_padding_no_trailing, \n",
    "                                                      digits=2)\n",
    "print(report_no_padding_no_trailing)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the metrics are worse after excluding trailing word pieces, because they are easy to predict. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "By fine-tuning the pre-trained BERT model for token classification, we achieved significantly better results compared to the [original paper on the wikigold dataset](https://www.aclweb.org/anthology/W09-3302) with a much smaller training dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.9,
       "encoder": "json",
       "name": "precision_1",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "precision_1"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.92,
       "encoder": "json",
       "name": "recall_1",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "recall_1"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.91,
       "encoder": "json",
       "name": "f1_1",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "f1_1"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.82,
       "encoder": "json",
       "name": "precision_2",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "precision_2"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.86,
       "encoder": "json",
       "name": "recall_2",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "recall_2"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.83,
       "encoder": "json",
       "name": "f1_2",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "f1_2"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# for testing\n",
    "report_no_padding_split = report_no_padding.split('\\n')[-2].split()\n",
    "report_no_padding_no_trailing_split = report_no_padding_no_trailing.split('\\n')[-2].split()\n",
    "\n",
    "sb.glue(\"precision_1\", float(report_no_padding_split[2]))\n",
    "sb.glue(\"recall_1\", float(report_no_padding_split[3]))\n",
    "sb.glue(\"f1_1\", float(report_no_padding_split[4]))\n",
    "sb.glue(\"precision_2\", float(report_no_padding_no_trailing_split[2]))\n",
    "sb.glue(\"recall_2\", float(report_no_padding_no_trailing_split[3]))\n",
    "sb.glue(\"f1_2\", float(report_no_padding_no_trailing_split[4]))"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "nlp_gpu",
   "language": "python",
   "name": "nlp_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
